Open In Colab

Tensor Puzzles - Penzai Edition

This is a version of the tensor puzzles implemented with the JAX Penzai library.

Penzai is a really nice fit for these puzzles both because it comes with a really clean visualization library built-in and because it has a very nice named-tensor implementation.

I recommend running in Colab. Click here and copy the notebook to get started.

#!pip install -qqq jaxtyping hypothesis pytest penzai
# NOTE: `np` is jax.numpy in this notebook; plain NumPy is available as `onp`.
import jax.numpy as np
import numpy as onp
from penzai import pz
# Short aliases for the penzai named-array helpers used by the puzzles.
arange = pz.nx.arange
# np.where passed through pz.nx.nmap so it can be called on named arrays.
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
# Presumably makes penzai's treescope the default notebook renderer -- confirm
# against the penzai documentation.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True,  prefers_column=["j"], prefers_row=["i"]))
# inspect reads the puzzle signatures in make_test; random picks axis sizes.
import inspect
import random
from jaxtyping import Int32
# Alias used inside the jaxtyping annotations of every puzzle signature.
NamedArray = pz.nx.NamedArray
def make_test(name, problem, problem_spec, add_sizes=None,
              init_size=None,
              constraint=lambda d: d):
    """Check a puzzle solution against its reference spec on random inputs.

    The axis names of every argument and of the return value are read from
    the jaxtyping annotations on ``problem``. Three random instances are
    generated; for each one the list-based ``problem_spec`` fills in the
    expected output and ``problem`` is called on the named arrays. Prints
    "Correct" when all instances match and "Failure" otherwise.

    Args:
        name: Label for the puzzle. Accepted for call-site readability; not
            read by this function.
        problem: The solution under test, annotated with jaxtyping named
            shapes. It is called with keyword arguments.
        problem_spec: Reference implementation. It is called positionally
            with nested-list versions of the inputs followed by the output
            placeholder, which it must fill in place.
        add_sizes: Accepted for backward compatibility; not read by this
            function.
        init_size: Optional mapping from the first character of an axis
            name to a fixed size. Axes not listed get a random size in
            [2, 7].
        constraint: Callable applied to the dict of generated arrays
            (including the "return" placeholder) before the spec runs; it
            returns the dict to use.

    Returns:
        A list with one dict per instance holding the inputs, the expected
        "target" and, when the solution returned something, "yours".
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    init_size = {} if init_size is None else init_size
    num_trials = 3

    args = {}
    signature = inspect.signature(problem)
    for n, p in signature.parameters.items():
        args[n] = [d.name for d in p.annotation.dims]
    args["return"] = [d.name for d in signature.return_annotation.dims]

    def make_instance():
        example = {}
        reg = {}
        sizes = dict(init_size)
        for n in args:
            size = {}
            for dim in args[n]:
                # Sizes are keyed by the first character of the axis name,
                # so e.g. "i1" and "i2" always get the same length.
                if dim[0] not in sizes:
                    sizes[dim[0]] = random.randint(2, 7)
                size[dim] = sizes[dim[0]]
            if "_s" in n:
                # "*_s" arguments are index vectors along their first axis.
                axis = next(iter(size))
                example[n] = pz.nx.arange(axis, size[axis])
            else:
                v = onp.random.randint(-5, 5, list(size.values()))
                example[n] = pz.nx.wrap(v).tag(*args[n])
        example = constraint(example)
        for n in args:
            x = example[n].untag(*args[n])
            reg[n] = x.unwrap().tolist()
            if len(args[n]) == 0:
                # Scalars are handed to the spec as a one-element list so it
                # can write the result in place.
                reg[n] = [0]
        return example, reg

    examples = []
    correct = 0
    for _ in range(num_trials):
        example, reg = make_instance()
        del example["return"]
        # The spec mutates the last positional argument (the placeholder).
        problem_spec(*reg.values())
        if len(args["return"]) == 0:
            # Unwrap the scalar placeholder. Keyed on "no return axes", not
            # on list length, so a length-1 vector stays a vector.
            reg["return"] = reg["return"][0]
        yours = problem(**example)
        example["target"] = wrap(reg["return"]).tag(*args["return"])
        if yours is not None:
            example["yours"] = yours
            same = example["target"] == yours
            if same.untag(*same.named_shape.keys()).unwrap().all():
                correct += 1
        # A solution that returns None counts as a failed instance.
        examples.append(example)
    if correct == num_trials:
        print("Correct")
    else:
        print("Failure")
    return examples

Rules

  1. Each puzzle needs to be solved in 1 line (<80 columns) of code.
  2. You are only allowed to use contract, where and indexing.
  3. You are not allowed anything else. No view, sum, take, squeeze, tensor.
# Example of named infix ops.
# Each (a, b) pair below uses two different axis names ("i" with "j").
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

# Bare expression so the notebook displays it; the comprehension's a and b
# are local to the comprehension and do not rebind the lists above.
[{"a": a, "b":b, "ret": a + b} for a, b in zip(a, b)]
(Loading...)
# Example of where
# First case: condition and both branches all live on axis "i".
# Second case: an "i" x "j" condition with one branch on "i" and one on "j".
examples = [(wrap([False, True], "i"), wrap([1, 1], "i"), wrap([-1, 0], "i")),
            (wrap([[False, True], [True, False]], "i", "j"), wrap([0, 1], "i"), wrap([-1, 0], "j")),
           ]
# Bare expression so the notebook displays each input alongside the result.
[{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, b in examples]
(Loading...)
# Example of contraction
def contract(n, *ts):
    """Multiply all of `ts` together and sum the product over named axis `n`."""
    product = 1
    for factor in ts:
        product = product * factor
    # Untag the contracted axis so np.sum (applied through nmap) sees it
    # as a positional axis.
    positional = product.untag(n)
    return pz.nx.nmap(np.sum)(positional)

# Same inputs as the infix example: each pair mixes an "i" and a "j" vector.
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]

# The product is formed first, then contract sums it over axis "i".
[{"a": a, "b":b, "ret": contract("i", a * b)} for a, b in zip(a, b)]
(Loading...)

Puzzle 1 - ones

Compute ones - the vector of all ones.

def ones_spec(i_s, out):
    """Reference for `ones`: write 1 into `out` at every index in `i_s`."""
    for idx in i_s:
        out[idx] = 1

# Solution: zero out the index vector, then add 1, giving a 1 at every
# position along "i".
def ones(i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return i_s * 0 + 1

# Runs three random instances and prints "Correct" or "Failure".
make_test("one", ones, ones_spec)
Correct
(Loading...)

Puzzle 2 - sum

Compute sum - the sum of a vector.

def sum_spec(a, out):
    """Reference for `sum`: accumulate every element of `a` into `out[0]`."""
    for value in a:
        out[0] = out[0] + value

# Solution: contract `a` over its only axis "i", leaving a scalar.
# NOTE: this definition shadows the builtin `sum` for the rest of the file.
def sum(a: Int32[NamedArray, "i"]) -> Int32[NamedArray, ""]:
    return contract("i", a)

# Runs three random instances and prints "Correct" or "Failure".
make_test("sum", sum, sum_spec)
Correct
(Loading...)

Puzzle 3 - outer

Compute outer - the outer product of two vectors.

def outer_spec(a, b, out):
    """Reference for `outer`: set `out[i][j]` to `a[i] * b[j]`."""
    for i, row in enumerate(out):
        for j in range(len(out[0])):
            row[j] = a[i] * b[j]

# Solution: `a` lives on axis "i" and `b` on axis "j"; their product is the
# "i j" array named in the return annotation.
def outer(a: Int32[NamedArray, "i"], b : Int32[NamedArray, "j"]) -> Int32[NamedArray, "i j"]:
    return a * b

# Runs three random instances and prints "Correct" or "Failure".
make_test("outer", outer, outer_spec)
Correct
(Loading...)

Puzzle 4 - diag

Compute diag - the diagonal vector of a square matrix.

def diag_spec(a, i1_s, out):
    """Reference for `diag`: copy the main diagonal of `a` into `out`."""
    for i, row in enumerate(a):
        out[i] = row[i]

# Solution: index both axes of `a` with the same "i1" index vector, which
# picks a[k, k] for each k and leaves a single "i1" axis.
def diag(a: Int32[NamedArray, "i1 i2"], i1_s: Int32[NamedArray, "i1"]) -> Int32[NamedArray, "i1"]:
    return a[{"i1": i1_s, "i2": i1_s}]


# Runs three random instances and prints "Correct" or "Failure".
make_test("diag", diag, diag_spec)
Correct
(Loading...)

Puzzle 5 - eye

Compute eye - the identity matrix.

def eye_spec(i1_s, i2_s, out):
    """Reference for `eye`: 1 where row index equals column index, else 0."""
    for i in i1_s:
        for j in i2_s:
            out[i][j] = 1 if i == j else 0

# Solution: compare the "i1" index vector with the "i2" index vector and
# turn the equality mask into 1/0 with where.
def eye(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return where(i1_s == i2_s, 1, 0)


# Runs three random instances and prints "Correct" or "Failure".
make_test("eye", eye, eye_spec)
Correct
(Loading...)

Puzzle 6 - triu

Compute triu - the upper triangular matrix.

def triu_spec(i1_s, i2_s, out):
    """Reference for `triu`: 1 on and above the main diagonal, else 0."""
    for i in i1_s:
        for j in i2_s:
            out[i][j] = 1 if i <= j else 0

# Solution: same shape as `eye`, but the mask is "row index <= column index".
def triu(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    return where(i1_s <= i2_s, 1, 0)


# Runs three random instances and prints "Correct" or "Failure".
make_test("triu", triu, triu_spec)
Correct
(Loading...)

Puzzle 7 - cumsum

Compute cumsum - the cumulative sum.

def cumsum_spec(a, i1_s, i2_s, out):
    """Reference for `cumsum`: `out[i]` is the sum of `a[0]` through `a[i]`."""
    running = 0
    for pos in range(len(out)):
        running = running + a[pos]
        out[pos] = running

# Solution: build the 1/0 mask "i1 <= i2", multiply it by `a`, and sum over
# "i1", so each "i2" position adds up the entries of `a` at or before it.
def cumsum(a: Int32[NamedArray, "i1"], i1_s : Int32[NamedArray, "i1"], i2_s: Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    return contract("i1", where(i1_s <= i2_s, 1, 0), a)

# Runs three random instances and prints "Correct" or "Failure".
make_test("cumsum", cumsum, cumsum_spec)
Correct
(Loading...)

Puzzle 8 - diff

Compute diff - the running difference.

def diff_spec(a, i_s, out):
    """Reference for `diff`: `out[i] = a[i] - a[i - 1]` for every position.

    NOTE: the loop starts at 0, so the initial `out[0] = a[0]` is overwritten
    with `a[0] - a[-1]` (the last element wraps around).
    """
    out[0] = a[0]
    for pos in range(len(out)):
        previous = a[pos - 1]
        out[pos] = a[pos] - previous

# Solution: subtract from `a` a copy of itself indexed one position earlier.
# At position 0 the index is -1; diff_spec's loop also starts at 0 and reads
# a[i - 1], so both use the last element there.
def diff(a: Int32[NamedArray, "i"], i1_s : Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a - a[{"i": i1_s - 1}]

# Runs three random instances and prints "Correct" or "Failure".
make_test("diff", diff, diff_spec)
Correct
(Loading...)

Puzzle 9 - stack

Compute stack - the matrix formed by stacking two vectors as its rows.

def stack_spec(a, b, out):
    """Reference for `stack`: row 0 of `out` is `a`, row 1 is `b`."""
    width = len(out[0])
    for col in range(width):
        out[0][col] = a[col]
        out[1][col] = b[col]

# Solution: a length-2 "j" index vector selects `b` where j == 1 and `a`
# otherwise, giving the "j i" result named in the return annotation.
def stack(a: Int32[NamedArray, "i"], b: Int32[NamedArray, "i"]) -> Int32[NamedArray, "j i"]:
    return where(arange("j", 2) == 1, b, a)


# init_size pins the "j" axis to length 2 instead of a random size.
make_test("stack", stack, stack_spec, init_size={"j" : 2})
Correct
(Loading...)

Puzzle 10 - roll

Compute roll - the vector shifted 1 circular position.

def roll_spec(a, i_s, out):
    """Reference for `roll`: `out[i]` is the next element of `a`, wrapping."""
    n = len(out)
    for i in range(n):
        # (i + 1) % n equals i + 1 everywhere except the last slot, where it
        # wraps to 0.
        out[i] = a[(i + 1) % n]

# Solution: index `a` with (position + 1) modulo the length of axis "i", so
# the last position wraps around to index 0.
def roll(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (i_s + 1) % i_s.named_shape["i"]}]


# Runs three random instances and prints "Correct" or "Failure".
make_test("roll", roll, roll_spec)
Correct
(Loading...)

Puzzle 11 - flip

Compute flip - the reversed vector

def flip_spec(a, i_s, out):
    """Reference for `flip`: `out` is `a` read from the back to the front."""
    n = len(out)
    for i in range(n):
        out[i] = a[n - 1 - i]

# Solution: -i_s - 1 maps positions 0, 1, 2, ... to -1, -2, -3, ...;
# presumably negative indices count from the end as in NumPy, which matches
# flip_spec reading a[len(out) - i - 1].
def flip(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    return a[{"i": (-i_s - 1)}]

# add_sizes is accepted by make_test but not read in its body.
make_test("flip", flip, flip_spec, add_sizes=["i"])
Correct
(Loading...)

Puzzle 12 - compress

Compute compress - keep only masked entries (left-aligned).

def compress_spec(g, v, i1_s, i2_s, out):
    """Reference for `compress`: pack selected values of `v` to the left.

    `out` is zeroed first; then every `v[i]` whose gate `g[i]` is strictly
    greater than 1 is written to the next free slot of `out`.
    """
    for slot in range(len(out)):
        out[slot] = 0
    dest = 0
    for i, gate in enumerate(g):
        if gate > 1:
            out[dest] = v[i]
            dest += 1

# Unsolved placeholder: it returns the gate `g` (axis "i1") unchanged rather
# than the compressed values, so the check below is expected to print
# "Failure".
def compress(g: Int32[NamedArray, "i1"], v: Int32[NamedArray, "i2"], i1_s:Int32[NamedArray, "i1"], i2_s:Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    # I don't know how to do this one!
    return g

# Runs three random instances and prints "Correct" or "Failure".
make_test("compress", compress, compress_spec)
Failure
(Loading...)

Puzzle 13 - pad_to

Compute pad_to - eliminate or add 0s to change size of vector.

def pad_to_spec(a, i_s, j_s, out):
    """Reference for `pad_to`: copy `a` into `out`, zero-filling any excess.

    When `out` is shorter than `a`, the extra elements of `a` are dropped.
    """
    for pos in range(len(out)):
        out[pos] = a[pos] if pos < len(a) else 0

# Solution: the 1/0 mask "j == i" pairs each "j" position with the matching
# "i" position; contracting it with `a` over "i" copies a[j] where it exists
# and leaves 0 where no "i" position matches.
def pad_to(a: Int32[NamedArray, "i"], i_s:Int32[NamedArray, "i"], j_s:Int32[NamedArray, "j"])  -> Int32[NamedArray, "j"]:
    return contract("i", a, where(j_s == i_s, 1, 0))


# Runs three random instances and prints "Correct" or "Failure".
make_test("pad_to", pad_to, pad_to_spec)
Correct
(Loading...)

Puzzle 14 - sequence_mask

Compute sequence_mask - pad out to length per batch.

# Didn't do
# def sequence_mask_spec(values, length, out):
#     for i in range(len(out)):
#         for j in range(len(out[0])):
#             if j < length[i]:
#                 out[i][j] = values[i][j]
#             else:
#                 out[i][j] = 0

# def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
#     pass


# def constraint_set_length(d):
#     d["length"] = d["length"] % d["values"].shape[1]
#     return d

# make_test("sequence_mask",
#     sequence_mask, sequence_mask_spec, constraint=constraint_set_length
# )

Puzzle 15 - bincount

Compute bincount - count number of times an entry was seen.

def bincount_spec(a, i_s, j1_s, j2_s, out):
    """Reference for `bincount`: `out[k]` counts how often `k` occurs in `a`."""
    for slot in range(len(out)):
        out[slot] = 0
    for value in a:
        out[value] += 1

# Solution: reuse the `eye` solution as a lookup table. Indexing its "j1"
# axis with `a` picks one identity row per element (axes "i" and "j2");
# contracting over "i" then adds those rows up into per-bin counts.
def bincount(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"],
             j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", eye(j1_s, j2_s)[{"j1": a}])


def constraint_set_max(d):
    """Reduce `d["a"]` modulo the "j2" length of `d["return"]`; return `d`."""
    values = d["a"]
    num_bins = d["return"].named_shape["j2"]
    d["a"] = values % num_bins
    return d


# The constraint is applied to the generated arrays before the spec runs, so
# every entry of `a` is reduced modulo the output's "j2" length.
make_test("bincount",
    bincount, bincount_spec, constraint=constraint_set_max
)
Correct
(Loading...)

Puzzle 16 - scatter_add

Compute scatter_add - add together values that link to the same location.

def scatter_add_spec(values, link, j1_s, j2_s, out):
    """Reference for `scatter_add`: add each `values[j]` into `out[link[j]]`."""
    for slot in range(len(out)):
        out[slot] = 0
    for j, value in enumerate(values):
        out[link[j]] += value

# Solution: same lookup as `bincount` (identity rows selected by `link`), but
# each row is multiplied by its entry of `values` before summing over "i".
def scatter_add(values: Int32[NamedArray, "i"], link: Int32[NamedArray,"i"],
                j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    return contract("i", values, eye(j1_s, j2_s)[{"j1": link}])


def constraint_set_max(d):
    """Reduce `d["link"]` modulo the "j2" length of `d["return"]`; return `d`.

    NOTE: this rebinds the name of the earlier `constraint_set_max`, which
    constrained the "a" entry instead.
    """
    targets = d["link"]
    num_slots = d["return"].named_shape["j2"]
    d["link"] = targets % num_slots
    return d

# The constraint is applied to the generated arrays before the spec runs, so
# every entry of `link` is reduced modulo the output's "j2" length.
make_test("scatter_add",
    scatter_add, scatter_add_spec, constraint=constraint_set_max
)
Correct
(Loading...)
import ast
import inspect
import textwrap

fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, stack, roll, flip,
       compress, pad_to,  bincount, scatter_add)

# Report the width of every solution body so the "1 line, <80 columns" rule
# can be checked at a glance.
for fn in fns:
    source = inspect.getsource(fn)
    # Find the body through the AST rather than by counting source lines, so
    # a signature that wraps over two lines is not reported as a multi-line
    # solution.
    func_def = ast.parse(textwrap.dedent(source)).body[0]
    first = func_def.body[0].lineno
    last = func_def.body[-1].end_lineno
    # Blank and comment-only lines do not count towards the solution.
    body = [line for line in source.split("\n")[first - 1:last]
            if line.strip() and not line.strip().startswith("#")]

    if len(body) > 1:
        print(fn.__name__, len(body[0]), "(more than 1 line)")
    else:
        print(fn.__name__, len(body[0]))
ones 22
sum 27
outer 16
diag 38
eye 36
triu 36
cumsum 55
diff 33
stack 43
roll 53
flip 31
compress 12
pad_to 52
bincount 52 (more than 1 line)
scatter_add 63 (more than 1 line)